import numpy as np
import torch

from ModularUtils.FunctionsConstant import asKey, getdoKey
from ModularUtils.FunctionsTraining import save_results
from ModularUtils.FunctionsDistribution import get_joint_distributions_from_samples, calculate_TVD, calculate_KL, \
    match_with_true_dist
from ModularUtils.ControllerConstants import map_dictfill_to_discrete
from ModularUtils.ControllerModel import get_generated_labels
from ModularUtils.Experiment_Class import Experiment
from ModularUtils.Functions_Plot_Results import plot_saved_results
from Sachs_experiment.GroundTruth.CausalGraph_Sachs import set_sachs_nonId_graph


def get_expected_loss_interventions(Exp, cur_mechs, label_generators, tvd_diff, kl_diff, true_dists=None):
    """Average TVD/KL between generated and reference distributions per query.

    For every query in ``Exp.interv_queries`` whose ``"obs"`` variables
    intersect ``cur_mechs``, samples are generated for each intervention in
    ``query["intervs"]``. Each set of samples is compared, via
    ``match_with_true_dist``, with the reference distribution at the same
    position in ``true_dists``. The per-intervention scores are averaged with
    equal weights. The averages, rounded to 4 decimals, are appended to
    ``tvd_diff[query["expr"]]`` and ``kl_diff[query["expr"]]``.

    Args:
        Exp: Experiment object. Read here: ``interv_queries`` and
            ``Synthetic_Sample_Size``; the helpers may read more.
        cur_mechs: Variable names. Queries whose ``"obs"`` variables do not
            intersect this collection are skipped.
        label_generators: Passed through unchanged to ``get_generated_labels``.
        tvd_diff, kl_diff: Mapping ``query["expr"]`` -> list of scores,
            mutated in place. Every evaluated query must already have a key,
            otherwise a KeyError is raised.
        true_dists: Optional list of reference distributions
            (``{outcome_tuple: probability}``), indexed like
            ``query["intervs"]``. It needs at least as many entries as any
            evaluated query has interventions. Defaults to the hard-coded
            P(Akt|Erk) values, which cover three interventions.

    Returns:
        ``(tvd_diff, kl_diff, true_expected_dist, fake_expected_dist)``. The
        last two are always empty dicts; they are kept for interface
        compatibility.
    """
    feat = "feature"

    fake_expected_dist = {}
    true_expected_dist = {}

    if true_dists is None:
        # P(Akt|Erk)
        true_dists = [
            {(0,): 0.6693, (1,): 0.3305, (2,): 0.0002},
            {(0,): 0.7409, (1,): 0.2588, (2,): 0.0003},
            {(0,): 0.1420, (1,): 0.6805, (2,): 0.1775},
        ]

    for query in Exp.interv_queries:
        # Only evaluate queries that touch at least one current mechanism.
        if not set(query["obs"]) & set(cur_mechs):
            continue

        n_intervs = len(query["intervs"])
        tvd_sum = 0
        kl_sum = 0
        for idx, intv_key in enumerate(query["intervs"]):
            generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_key, query["obs"],
                                                         Exp.Synthetic_Sample_Size)
            generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, query["obs"])

            obs_tvd, obs_kl, _, _ = match_with_true_dist(Exp, query["obs"], generated_labels_full,
                                                         true_dists[idx], feat, doPrint=False)

            # sqrt(KL/2) is the Pinsker upper bound on TVD (assuming KL in nats).
            print(f'{intv_key}: tvd:{obs_tvd}, kl:{obs_kl} and tvd<={np.sqrt(0.5 * obs_kl)}')

            # Equal weight per intervention (uniform average over query["intervs"]).
            tvd_sum += obs_tvd / n_intervs
            kl_sum += obs_kl / n_intervs

        print(f'--->Average tvd:{tvd_sum}, kl:{kl_sum} and tvd<={np.sqrt(0.5 * kl_sum)}')
        tvd_diff[query["expr"]].append(round(tvd_sum, 4))
        kl_diff[query["expr"]].append(round(kl_sum, 4))

    return tvd_diff, kl_diff, true_expected_dist, fake_expected_dist



# Hard-coded reference distributions for interventional queries (non-empty
# intervention), keyed by ``query["expr"]``. These queries are not estimated
# from ``dataset_dict``.
# NOTE(review): provenance of these numbers is not visible here — confirm source.
_SACHS_REFERENCE_INTERV_DISTS = {
    "P(Mek|do[PKA=2])": {(0,): 1.0, (1,): 1e-06, (2,): 1e-06},
    "P(Erk|do[PKA=2])": {(0,): 0.07460035523978685, (1,): 0.7069271758436945, (2,): 0.21847246891651864},
    "P(Akt|do[PKA=2])": {(0,): 0.8046181172291297, (1,): 0.19538188277087035, (2,): 1e-06},
}


def _observational_real_dist(Exp, compare_Var, intv_key, dataset_dict, all_real_labels):
    """Store the real samples for ``intv_key`` and return their joint distribution.

    The real samples are the columns of ``dataset_dict[intv_key]["obs"]`` that
    correspond to ``compare_Var``, located via ``Exp.label_names``. They are
    stored in ``all_real_labels[intv_key]``. If ``intv_key`` is missing from
    ``dataset_dict``, an empty tensor is stored instead.
    """
    obs_indices = [Exp.label_names.index(lb) for lb in compare_Var]
    current_real_label = []
    if intv_key in dataset_dict:
        current_real_label = (dataset_dict[intv_key]["obs"][:, obs_indices]
                              .type(torch.LongTensor)
                              .view(-1, len(obs_indices))
                              .to(Exp.DEVICE))
    # as_tensor avoids the copy-construct warning torch.tensor() emits for tensor input.
    all_real_labels[intv_key] = torch.as_tensor(current_real_label)
    return get_joint_distributions_from_samples(
        Exp, compare_Var, all_real_labels[intv_key].detach().cpu().numpy().astype(int), "feature")


def sachsEvaluation(Exp, cur_mechs, label_generators, dataset_dict, tvd_diff, kl_diff):
    """Evaluate the label generators on ``Exp.interv_queries`` and record TVD/KL.

    For each query and each intervention in ``query["intervs"]``, samples are
    drawn from the generators and turned into a joint distribution over
    ``query["obs"]``. That distribution is compared with one of two references:

    * the empirical distribution of the real samples in ``dataset_dict``, when
      the intervention is empty;
    * the hard-coded entry in ``_SACHS_REFERENCE_INTERV_DISTS`` for
      ``query["expr"]``, otherwise.

    Interventional queries without a reference are skipped rather than
    compared against a leftover distribution. Their generated samples are
    still passed to ``save_results``.

    Args:
        Exp: Experiment object. Read here: ``interv_queries``, ``label_names``,
            ``DEVICE``, ``Synthetic_Sample_Size``, ``SAVED_PATH``,
            ``G_avg_losses`` and ``D_avg_losses``.
        cur_mechs: Not used by this function; kept for interface compatibility.
        label_generators: Mapping name -> object with ``eval()``/``train()``
            (presumably ``torch.nn.Module``). Switched to eval mode during
            evaluation and always restored to train mode afterwards.
        dataset_dict: Mapping intervention key (as produced by ``asKey``) ->
            dict whose ``"obs"`` entry is a 2-D tensor of real samples with
            columns ordered like ``Exp.label_names``.
        tvd_diff, kl_diff: Mapping query string -> list of past scores. A new
            score, rounded to 4 decimals, is appended for each query string
            already present as a key.

    Returns:
        ``(tvd_diff, kl_diff)``, mutated in place.
    """
    for gen in label_generators:
        label_generators[gen].eval()

    all_generated_labels = {}
    all_real_labels = {}
    try:
        with torch.no_grad():
            for query in Exp.interv_queries:
                compare_Var = query["obs"]
                for key in query["intervs"]:
                    intv_key = asKey(key)
                    generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, dict(intv_key),
                                                                 compare_Var, Exp.Synthetic_Sample_Size)
                    generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, compare_Var)
                    all_generated_labels[intv_key] = torch.tensor(generated_labels_full)
                    fake_dist = get_joint_distributions_from_samples(
                        Exp, compare_Var,
                        all_generated_labels[intv_key].detach().cpu().numpy().astype(int), "feature")

                    if not dict(intv_key):
                        real_dist = _observational_real_dist(Exp, compare_Var, intv_key, dataset_dict,
                                                             all_real_labels)
                    elif query["expr"] in _SACHS_REFERENCE_INTERV_DISTS:
                        # Copy so downstream helpers can never mutate the shared table.
                        real_dist = dict(_SACHS_REFERENCE_INTERV_DISTS[query["expr"]])
                    else:
                        # No reference for this query: skip it instead of reusing
                        # real_dist from a previous iteration.
                        continue

                    tvd = calculate_TVD(real_dist, fake_dist, doPrint=False)
                    kl = calculate_KL(real_dist, fake_dist, doPrint=False)

                    query_str = getdoKey(compare_Var, dict(intv_key))
                    if query_str in tvd_diff:
                        tvd_diff[query_str].append(round(tvd, 4))
                        kl_diff[query_str].append(round(kl, 4))

            save_results(Exp, Exp.SAVED_PATH, all_generated_labels, all_real_labels,
                         tvd_diff, kl_diff, Exp.G_avg_losses, Exp.D_avg_losses)
    finally:
        # Restore training mode even if evaluation fails part-way.
        for gen in label_generators:
            label_generators[gen].train()

    # Print (up to) the last 10 recorded TVD scores of each query; safe when
    # tvd_diff is empty or its histories differ in length.
    for dist in tvd_diff:
        print("###", dist, " loss%:", tvd_diff[dist][-10:])
    print(Exp.SAVED_PATH)

    return tvd_diff, kl_diff



# Build the Experiment handle for "Exp1" on the Sachs non-identifiable graph.
# NOTE(review): new_experiment=False presumably reuses an existing run — confirm.
Exp = Experiment("Exp1", set_sachs_nonId_graph,
                 new_experiment=False,
                 features=["feature"],
                 Data_intervs=[{}])

# Alternative: plot against an NCM benchmark run with custom labels.
# plot_saved_results(Exp, "/path_to_project/SAVED_EXPERIMENTS/sachs_nonId_graph/Exp1/Apr_29_2023-14_23",
#                    2000, delta=20,
#                    pre_labels = ["$P(V)$", "$ncmP(V)$", "$P(Mek|do[PKA=2])$", "$ncmP(Mek|do[PKA=2])$", '$P(Akt|do[PKA=2])$', '$ncmP(Akt|do[PKA=2])$'],
#                    benchmarks=('ncm',"/path_to_project/SAVED_EXPERIMENTS/sachs_nonId_graph/Exp1/Apr_29_2023-03_11"))

root = "/path_to_project/SAVED_EXPERIMENTS/sachs_nonId_graph/Exp1"
exp = 'Dec_05_2022-03_28'
# Run directories under `root` to overlay as benchmarks (each labelled 'ncm').
# Empty -> plot only the run in `last_exp`.
bnc_exp = []

last_exp = f"{root}/{exp}"
benchmarks = [('ncm', f'{root}/{bnc}') for bnc in bnc_exp]

plot_saved_results(Exp, last_exp, epochs=1000, delta=10,
                   pre_labels=None, benchmarks=benchmarks)  # only whatifgan